package org.example.nebula.deploy

import com.facebook.thrift.protocol.TCompactProtocol
import com.vesoft.nebula.algorithm.config.{CcConfig, CoefficientConfig, LPAConfig, LouvainConfig, PRConfig}
import com.vesoft.nebula.algorithm.lib.{ClusteringCoefficientAlgo, ConnectedComponentsAlgo, DegreeStaticAlgo, LabelPropagationAlgo, LouvainAlgo, PageRankAlgo, StronglyConnectedComponentsAlgo, TriangleCountAlgo}
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, dense_rank}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import spray.json.DefaultJsonProtocol.{IntJsonFormat, StringJsonFormat, listFormat}
import spray.json._

object RunNebulaAlgo {

    /**
     * Entry point. Expects a single CLI argument: a JSON document with fields
     * "type" (input source: "file" | "jdbc" | "text"), "algo" (algorithm name),
     * "params" (source-specific options) and "result" (how to emit the output).
     */
    def main(args: Array[String]): Unit = {

        val sparkConf = new SparkConf()
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol]))
        val spark = SparkSession
                .builder()
                .appName("deploy-nebula-algorithm")
                .master("yarn")
                .config(sparkConf)
                .getOrCreate()

        val paramJson = args(0).parseJson
        val inputType = paramJson.asJsObject.fields("type").convertTo[String]
        val algoName = paramJson.asJsObject.fields("algo").convertTo[String]
        val data = paramJson.asJsObject.fields("params").asJsObject
        val result = paramJson.asJsObject.fields("result").asJsObject

        // Dispatch on the input source. An unrecognized type falls through to
        // an empty DataFrame, matching the original if/else behavior.
        val returnResult = inputType match {
            // read the edge list from a CSV file on the local filesystem
            case "file" => runAlgorithmWithFile(spark, algoName, data)
            // read the edge list from a MySQL table over JDBC
            case "jdbc" => runAlgorithmWithJdbc(spark, algoName, data)
            // edge list passed inline inside the request JSON
            case "text" => runAlgorithmWithText(spark, algoName, data)
            case _      => spark.emptyDataFrame
        }

        handleAlgoResult(algoName, returnResult, result)

        spark.stop()
    }

    /**
     * Loads an edge list from a JDBC source and runs the requested algorithm.
     *
     * Expected fields in `data`: url, driver, user, password, dbtable,
     * srcColumn, dstColumn. The named columns are renamed to "src"/"dst"
     * before being handed to [[calculateData]].
     */
    def runAlgorithmWithJdbc(spark: SparkSession, algoName: String, data: JsObject): DataFrame = {
        val url = data.fields("url").convertTo[String]
        val driver = data.fields("driver").convertTo[String]
        val user = data.fields("user").convertTo[String]
        val password = data.fields("password").convertTo[String]
        val dbtable = data.fields("dbtable").convertTo[String]
        val srcColumn = data.fields("srcColumn").convertTo[String]
        val dstColumn = data.fields("dstColumn").convertTo[String]

        // SECURITY: never log credentials in clear text — the password is masked.
        println(url + "\t" + driver + "\t" + user + "\t****\t" + dbtable + "\t" + srcColumn + "\t" + dstColumn)

        // Read the table through Spark's generic JDBC data source.
        val dataFrame: DataFrame = spark.read.format("jdbc")
                .option("url", url)
                .option("driver", driver)
                .option("user", user)
                .option("password", password)
                .option("dbtable", dbtable)
                .load()

        // Keep only the two endpoint columns, normalized to "src"/"dst".
        val df: DataFrame = dataFrame.select(srcColumn, dstColumn)
                .withColumnRenamed(srcColumn, "src")
                .withColumnRenamed(dstColumn, "dst")

        calculateData(df, spark, algoName)
    }

    /**
     * 计算文件类型的输入
     */
    /**
     * Loads a comma-delimited, headered CSV from the local filesystem
     * ("dataPath" field of `data`) and runs the requested algorithm over it.
     * The file is expected to provide the "src" and "dst" edge columns.
     */
    def runAlgorithmWithFile(spark: SparkSession, algoName: String, data: JsObject): DataFrame = {
        val dataName = data.fields("dataPath").convertTo[String]

        println(algoName + "\t" + dataName)

        // Prefix with the local-filesystem scheme so Spark does not resolve
        // the path against the default (e.g. HDFS) filesystem.
        val edges: DataFrame = spark.read
                .option("header", true)
                .option("delimiter", ",")
                .csv("file:///" + dataName)

        calculateData(edges, spark, algoName)
    }

    /**
     * Emits the algorithm result according to `resultType`.
     *
     * type == "top":  sort descending by the algorithm's score column and print
     *                 the top "number" rows (unknown algorithms print nothing).
     * type == "file": write the result as a single headered CSV to "resultPath"
     *                 on the local filesystem, overwriting any previous output.
     */
    def handleAlgoResult(algoName: String, result: DataFrame, resultType: JsObject): Unit = {
        val returnType = resultType.fields("type").convertTo[String]

        if (returnType == "top") {
            val number = resultType.fields("number").convertTo[Int]
            // Score column to rank by, per algorithm. NOTE: "cc" is lower-case
            // in the wire protocol, unlike the other algorithm names.
            val sortColumn = Map(
                "Degree" -> "degree",
                "PageRank" -> "pagerank",
                "ClusteringCoefficient" -> "clustercoefficient",
                "Louvain" -> "louvain",
                "LPA" -> "lpa",
                "SCC" -> "scc",
                "TriangleCount" -> "trianglecount",
                "cc" -> "cc"
            )
            // Single sort/show/take path replaces eight copy-pasted branches;
            // an unrecognized algorithm name does nothing, as before.
            sortColumn.get(algoName).foreach { column =>
                val sortedDF = result.orderBy(col(column).desc)
                sortedDF.show()
                sortedDF.take(number).foreach(println)
            }

        } else if (returnType == "file") {
            val resultPath = resultType.fields("resultPath").convertTo[String]

            val resultCsv = "file:///" + resultPath
            // repartition(1) forces a single output CSV part file.
            result.repartition(1).write.option("header", true).mode("overwrite").csv(resultCsv)
        }
    }


    /**
     * Builds an edge DataFrame from two parallel string lists ("src"/"dst")
     * passed inline in the request JSON, then runs the requested algorithm.
     * The lists are paired element-wise: src(i) -> dst(i).
     */
    def runAlgorithmWithText(spark: SparkSession, algoName: String, data: JsObject): DataFrame = {
        val sources = data.fields("src").convertTo[List[String]]
        val targets = data.fields("dst").convertTo[List[String]]

        // Zip the two lists into (src, dst) edge tuples and distribute them.
        val edgeRdd: RDD[(String, String)] = spark.sparkContext.makeRDD(sources.zip(targets))

        // Explicit schema: both endpoints are non-nullable strings.
        val edgeSchema = StructType(Seq(
            StructField("src", StringType, nullable = false),
            StructField("dst", StringType, nullable = false)
        ))

        val edges: DataFrame = spark.createDataFrame(edgeRdd.map(Row.fromTuple), edgeSchema)
        edges.show()

        calculateData(edges, spark, algoName)
    }

    /**
     * Encodes the string vertex ids of `df` ("src"/"dst") to dense integer
     * ids, runs the algorithm named by `algoName` from nebula-algorithm over
     * the encoded edges, then maps the ids back to their original strings.
     *
     * Unknown algorithm names produce an empty result (and an empty decoded
     * DataFrame). Columns other than "_id" in the algorithm output survive
     * the decode join.
     */
    def calculateData(df: DataFrame, spark: SparkSession, algoName: String): DataFrame = {
        // Distinct vertex ids appearing on either end of any edge.
        val srcIdDF: DataFrame = df.select("src").withColumnRenamed("src", "id")
        val dstIdDF: DataFrame = df.select("dst").withColumnRenamed("dst", "id")
        val idDF = srcIdDF.union(dstIdDF).distinct()

        // Assign each id a dense integer code. NOTE: a Window without
        // partitioning funnels every id through one partition — fine for
        // small graphs, a scaling bottleneck for large ones.
        val encodeId = idDF.withColumn("encodedId", dense_rank().over(Window.orderBy("id")))

        encodeId.show()
        println("共" + df.count() + "条边")
        println("共" + encodeId.count() + "个节点")

        // Replace the string src/dst ids with their integer codes.
        // Use explicit equi-join conditions instead of a condition-less
        // (cross) join followed by a filter, which depended on the optimizer
        // rewriting the pair into an inner join.
        val srcJoinDF = df
                .join(encodeId, col("src") === col("id"))
                .drop("src")
                .drop("id")
                .withColumnRenamed("encodedId", "src")
        // Cached because it is reused immediately for the dst-side join.
        srcJoinDF.cache()
        val dstJoinDF = srcJoinDF
                .join(encodeId, col("dst") === col("id"))
                .drop("dst")
                .drop("id")
                .withColumnRenamed("encodedId", "dst")

        val encodedDF = dstJoinDF.select("src", "dst")

        // Dispatch to the requested nebula-algorithm implementation; an
        // unknown name yields an empty frame (same as the original if-chain).
        val algoResult: DataFrame = algoName match {
            case "PageRank" =>
                // 3 iterations, damping factor 0.85; undirected = false
                PageRankAlgo.apply(spark, encodedDF, PRConfig(3, 0.85), false)
            case "ClusteringCoefficient" =>
                ClusteringCoefficientAlgo.apply(spark, encodedDF, new CoefficientConfig("local"))
            case "Degree" =>
                DegreeStaticAlgo.apply(spark, encodedDF).select("_id", "degree")
            case "Louvain" =>
                LouvainAlgo.apply(spark, encodedDF, LouvainConfig(10, 5, 0.5), false)
            case "LPA" =>
                LabelPropagationAlgo.apply(spark, encodedDF, LPAConfig(10), false)
            case "SCC" =>
                // run until convergence
                StronglyConnectedComponentsAlgo.apply(spark, encodedDF, CcConfig(Int.MaxValue), false)
            case "TriangleCount" =>
                TriangleCountAlgo.apply(spark, encodedDF)
            case "WCC" =>
                ConnectedComponentsAlgo.apply(spark, encodedDF, CcConfig(20), false)
            case _ =>
                spark.emptyDataFrame
        }
        algoResult.show()

        // Map the integer codes in the result back to the original string ids.
        val decodedPr = encodeId
                .join(algoResult, col("encodedId") === col("_id"))
                .drop("encodedId")
                .drop("_id")
        decodedPr.show()

        decodedPr
    }
}
